{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4cbc1487",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA_VISIBLE_DEVICES=\"2,3\" vllm/bin/vllm serve \"mesolitica/Malaysian-Qwen2.5-7B-Dialect-Reasoning-GRPO\" --port 8006 --tensor_parallel_size=2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0ab3bcb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "792e2282",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "# Evaluation prompts: the `question` column of this test split is what gets sent to the model.\n",
    "dataset = load_dataset(\n",
    "    'huseinzol05/malaysian-dialect-qa',\n",
    "    split='test',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "868eeb43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-example metadata used for scoring: `from_lang`, `to_lang` and the reference `answer`.\n",
    "dataset_lang = load_dataset(\n",
    "    'huseinzol05/malaysian-dialect-qa-lang',\n",
    "    split='test',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b376febd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "140"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair every question with its dataset index; the index later names the answer files.\n",
    "questions = [\n",
    "    (idx, dataset[idx]['question'])\n",
    "    for idx in range(len(dataset))\n",
    "]\n",
    "\n",
    "len(questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3eddca6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory where every sampled answer is cached as `{question index}-{sample index}.json`.\n",
    "folder = 'Malaysian-Qwen2.5-7B-Reasoning-GRPO-fp16'\n",
    "# Uncomment to discard the cached answers and regenerate everything from scratch.\n",
    "# !rm -rf {folder}\n",
    "# `-p` keeps this cell re-runnable: plain `mkdir` errors once the folder exists.\n",
    "!mkdir -p {folder}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f785ac6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import os\n",
    "import json\n",
    "import re\n",
    "\n",
    "def generate_answer(row, repeat = 5, max_retries = 10, timeout = 1800):\n",
    "    \"\"\"\n",
    "    Sample `repeat` answers for one question and cache each of them on disk.\n",
    "\n",
    "    Sample `k` of question `no` is stored as JSON in `{folder}/{no}-{k}.json`.\n",
    "    Samples whose file already contains valid JSON are skipped, so an\n",
    "    interrupted run can simply be restarted.\n",
    "\n",
    "    Args:\n",
    "        row: `(no, q)` tuple with the question index and the question text.\n",
    "        repeat: number of samples to collect for the question.\n",
    "        max_retries: attempts allowed per sample before it is given up on.\n",
    "        timeout: seconds to wait for the server on every request.\n",
    "\n",
    "    Returns:\n",
    "        None. The only result is the set of files written into `folder`.\n",
    "    \"\"\"\n",
    "    no, q = row\n",
    "    for k in range(repeat):\n",
    "        filename = os.path.join(folder, f'{no}-{k}.json')\n",
    "        try:\n",
    "            with open(filename) as fopen:\n",
    "                json.load(fopen)\n",
    "            continue\n",
    "        except (OSError, ValueError):\n",
    "            # Missing or corrupt cache file: (re)generate this sample.\n",
    "            # A bare `except` here would also swallow KeyboardInterrupt.\n",
    "            pass\n",
    "\n",
    "        json_data = {\n",
    "            'model': \"mesolitica/Malaysian-Qwen2.5-7B-Dialect-Reasoning-GRPO\",\n",
    "            'messages': [\n",
    "                {'role': 'system', 'content': 'You are going to enter reasoning mode. First, you try to think step-by-step in Malay. After that, put your final answer within $\\\\boxed{}$.'},\n",
    "                {'role': 'user', 'content': q},\n",
    "            ],\n",
    "            'max_tokens': 24000,\n",
    "        }\n",
    "\n",
    "        # Bounded retries: an unbounded loop never returns when the model keeps\n",
    "        # producing zero or several boxed answers for a question.\n",
    "        for attempt in range(max_retries):\n",
    "            try:\n",
    "                response = requests.post(\n",
    "                    'http://localhost:8006/v1/chat/completions',\n",
    "                    json=json_data,\n",
    "                    timeout=timeout,\n",
    "                )\n",
    "                response.raise_for_status()\n",
    "                r = response.json()['choices'][0]['message']['content'].strip()\n",
    "            except (requests.RequestException, KeyError, IndexError, TypeError, ValueError, AttributeError) as e:\n",
    "                # Transport error, HTTP error or unexpected payload: count it as a failed attempt.\n",
    "                print(f'{filename}: attempt {attempt + 1}/{max_retries} failed: {e!r}')\n",
    "                continue\n",
    "\n",
    "            # Accept the marker with or without the leading backslash before `boxed`.\n",
    "            answers = re.findall(r\"\\$\\\\?boxed\\{(.*?)\\}\\$\", r)\n",
    "            if len(answers) == 1:\n",
    "                with open(filename, 'w') as fopen:\n",
    "                    json.dump(answers[0], fopen)\n",
    "                break\n",
    "        else:\n",
    "            print(f'{filename}: no single boxed answer after {max_retries} attempts, skipping')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4cd42250",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: run a single question end-to-end before launching the worker threads.\n",
    "generate_answer(row=questions[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "dcaf4f5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def consumer(queue, name):\n",
    "    \"\"\"\n",
    "    Worker loop: take rows off `queue` and run `generate_answer` on each until the queue is empty.\n",
    "\n",
    "    Args:\n",
    "        queue: `queue.Queue` holding the rows to pass to `generate_answer`.\n",
    "        name: label used in the log messages of this worker.\n",
    "    \"\"\"\n",
    "    # Local import: the `queue` parameter shadows the module name in this scope, but\n",
    "    # `from ... import` resolves the module by name, not through that variable.\n",
    "    from queue import Empty\n",
    "\n",
    "    while True:\n",
    "        # `get_nowait` checks and takes in one step. Testing `qsize()` first and then\n",
    "        # calling a blocking `get()` lets another worker grab the last item in\n",
    "        # between, which leaves this thread blocked forever.\n",
    "        try:\n",
    "            item = queue.get_nowait()\n",
    "        except Empty:\n",
    "            break\n",
    "        try:\n",
    "            generate_answer(item)\n",
    "        except Exception as e:\n",
    "            # One failing row must not kill the worker and strand the remaining rows.\n",
    "            print(f'consumer {name} failed: {e!r}')\n",
    "    print(f'consumer {name} done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c8d11030",
   "metadata": {},
   "outputs": [],
   "source": [
    "from threading import Thread\n",
    "from queue import Queue\n",
    "\n",
    "# Work queue shared by the worker threads: one entry per (index, question) row.\n",
    "queue = Queue()\n",
    "for question_row in questions:\n",
    "    queue.put(question_row)\n",
    "\n",
    "# Remember the initial size so progress can be derived from what is left in the queue.\n",
    "ori_size = queue.qsize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "42bec1f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 99%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎| 139/140 [02:32<00:01,  1.10s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "consumer 29 done\n",
      "consumer 48 done\n",
      "consumer 35 done\n",
      "consumer 0 done\n",
      "consumer 6 done\n",
      "consumer 36 done\n",
      "consumer 9 done\n",
      "consumer 15 done\n",
      "consumer 33 done\n",
      "consumer 34 done\n",
      "consumer 22 done\n",
      "consumer 32 done\n",
      "consumer 37 done\n",
      "consumer 26 done\n",
      "consumer 1 done\n",
      "consumer 42 done\n",
      "consumer 47 done\n",
      "consumer 31 done\n",
      "consumer 19 done\n",
      "consumer 21 done\n",
      "consumer 3 done\n",
      "consumer 41 done\n",
      "consumer 23 done\n",
      "consumer 44 done\n",
      "consumer 49 done\n",
      "consumer 17 done\n",
      "consumer 20 done\n",
      "consumer 30 done\n",
      "consumer 13 done\n",
      "consumer 25 done\n",
      "consumer 16 done\n",
      "consumer 46 done\n",
      "consumer 45 done\n",
      "consumer 24 done\n",
      "consumer 40 done\n",
      "consumer 7 done\n",
      "consumer 10 done\n",
      "consumer 2 done\n",
      "consumer 11 done\n",
      "consumer 12 done\n",
      "consumer 5 done\n",
      "consumer 39 done\n",
      "consumer 4 done\n",
      "consumer 28 done\n",
      "consumer 43 done\n",
      "consumer 38 done\n",
      "consumer 18 done\n",
      "consumer 14 done\n",
      "consumer 27 done\n",
      "consumer 8 done\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "max_worker = 50\n",
    "consumers = [Thread(target=consumer, args=(queue, i)) for i in range(max_worker)]\n",
    "for worker in consumers:\n",
    "    worker.start()\n",
    "\n",
    "# The bar counts questions taken off the queue, i.e. started rather than finished.\n",
    "pbar = tqdm(total=ori_size)\n",
    "last_size = 0\n",
    "while True:\n",
    "    size = queue.qsize()\n",
    "    left = ori_size - size\n",
    "    minus = left - last_size\n",
    "    if minus > 0:\n",
    "        pbar.update(minus)\n",
    "        last_size += minus\n",
    "    # Check for completion only after updating, so the bar can reach 100%.\n",
    "    if size == 0:\n",
    "        break\n",
    "    # Sleep between polls: a tight loop burns a core and competes with the workers for the GIL.\n",
    "    time.sleep(0.5)\n",
    "\n",
    "# An empty queue only means the last questions were picked up. Wait until every\n",
    "# worker has actually finished so the following cells see all answer files.\n",
    "for worker in consumers:\n",
    "    worker.join()\n",
    "\n",
    "pbar.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a0be872e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sacrebleu.metrics import CHRF\n",
    "from glob import glob\n",
    "from collections import defaultdict\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7818b54c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:00<00:00, 721.17it/s]\n"
     ]
    }
   ],
   "source": [
    "chrf = CHRF()\n",
    "pairs = defaultdict(list)\n",
    "\n",
    "# NOTE(review): row i of `dataset_lang` is assumed to describe question i of `dataset`\n",
    "# (answer files are named after the `dataset` index) - confirm the two splits are aligned.\n",
    "for i in tqdm(range(len(dataset_lang))):\n",
    "    row = dataset_lang[i]\n",
    "    from_lang = row['from_lang']\n",
    "    to_lang = row['to_lang']\n",
    "    gt = row['answer']\n",
    "    pair = f'{from_lang}<>{to_lang}'\n",
    "    # Reuse `folder` so scoring always reads the directory the generation step wrote to.\n",
    "    files = glob(f'{folder}/{i}-*.json')\n",
    "    if len(files) < 5:\n",
    "        # Fewer cached samples than the 5 requested per question: report the index.\n",
    "        print(i)\n",
    "\n",
    "    scores = []\n",
    "    for f in files:\n",
    "        with open(f) as fopen:\n",
    "            d = json.load(fopen)\n",
    "        scores.append(chrf.corpus_score([d], [[gt]]).score)\n",
    "\n",
    "    if not scores:\n",
    "        # `max([])` raises ValueError; report the gap instead of aborting the whole evaluation.\n",
    "        print(f'no answers found for question {i}, excluded from {pair}')\n",
    "        continue\n",
    "\n",
    "    # Best-of-n: keep the highest chrF among the sampled answers of this question.\n",
    "    max_score = max(scores)\n",
    "    pairs[pair].append(max_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f2247f18",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: johor To: malay, score: 57.42949426937456\n",
      "From: kedah To: malay, score: 58.12580212528728\n",
      "From: pahang To: malay, score: 55.60484906845884\n",
      "From: negeri sembilan To: malay, score: 56.4509629484568\n",
      "From: kelantan To: malay, score: 53.944979416369996\n",
      "From: penang To: malay, score: 62.20935643642939\n",
      "From: melaka To: malay, score: 57.14492955494046\n",
      "From: malay To: johor, score: 55.68356840259747\n",
      "From: malay To: kedah, score: 56.264707994950186\n",
      "From: malay To: pahang, score: 60.15982036912563\n",
      "From: malay To: negeri sembilan, score: 48.71725827604103\n",
      "From: malay To: kelantan, score: 43.948995049469474\n",
      "From: malay To: penang, score: 63.15864675162173\n",
      "From: malay To: melaka, score: 74.12398375006538\n"
     ]
    }
   ],
   "source": [
    "# Mean of the per-question scores for every translation direction.\n",
    "for pair_name, pair_scores in pairs.items():\n",
    "    source, target = pair_name.split('<>')\n",
    "    print(f'From: {source} To: {target}, score:', np.mean(pair_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "380970be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "57.27291054561676"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average of the per-direction means for dialect -> Malay, taken straight from `pairs`\n",
    "# instead of numbers pasted from an earlier output, which go stale when scores are recomputed.\n",
    "scores = [\n",
    "    np.mean(pair_scores)\n",
    "    for pair_name, pair_scores in pairs.items()\n",
    "    if pair_name.split('<>')[1] == 'malay'\n",
    "]\n",
    "\n",
    "np.mean(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a2ff12e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "57.436711513410124"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average of the per-direction means for Malay -> dialect, taken straight from `pairs`\n",
    "# instead of numbers pasted from an earlier output, which go stale when scores are recomputed.\n",
    "scores = [\n",
    "    np.mean(pair_scores)\n",
    "    for pair_name, pair_scores in pairs.items()\n",
    "    if pair_name.split('<>')[0] == 'malay'\n",
    "]\n",
    "\n",
    "np.mean(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fdbb96a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
